package org.zjvis.graph.analysis.service.algo;


import java.util.*;
import java.util.concurrent.*;
import java.util.stream.Collectors;

/**
 * Community detection by (asynchronous) label propagation.
 *
 * <p>Every node starts with its own label ({@code 0..N-1}, in the key order of the input map).
 * Each pass visits the nodes in a freshly shuffled order and lets every node adopt the label that
 * occurs most often among its neighbors, breaking ties uniformly at random. Passes repeat until
 * one pass changes no label or the iteration limit is reached; {@link #execute()} then renumbers
 * the surviving labels densely as {@code 1..k}. Because of the shuffling and the random
 * tie-breaking, results are not deterministic in general.
 *
 * <p>Graphs with more than 10000 nodes are processed in small concurrent batches on a private
 * thread pool; smaller graphs are processed on the calling thread. An instance itself is not
 * synchronized and must not be used from several threads at once.
 */
public class LabelPropagation {

    /** Default maximum number of propagation passes. */
    private static final int DEFAULT_MAX_ITERATIONS = 10;
    /** Graphs with more nodes than this are handled by {@link #executeParallel()}. */
    private static final int PARALLEL_THRESHOLD = 10000;
    /** Worker threads in {@link #executeParallel()}; also the number of nodes updated per batch. */
    private static final int PARALLEL_THREADS = 5;

    /** Adjacency lists keyed by node id; the keys are the nodes of the graph. */
    private final Map<Object, List<Object>> neighborMap;
    /**
     * Current label of every node. All keys are inserted by the constructor, so later updates
     * (including the concurrent ones in parallel mode) only replace values of existing keys, which
     * HashMap does not treat as a structural modification.
     */
    private final Map<Object, Integer> labelMap;
    private final int nodeCount;
    /** Node ids; shuffled in place before every pass to randomize the visiting order. */
    private final List<Object> idList;
    private int maxIterations = DEFAULT_MAX_ITERATIONS;

    /**
     * Creates a propagation run over the given graph.
     *
     * @param neighborMap adjacency lists keyed by node id; neighbors that are not themselves keys
     *     of this map are ignored, and a {@code null} list is treated as "no neighbors"
     * @throws NullPointerException if {@code neighborMap} is {@code null}
     */
    public LabelPropagation(Map<Object, List<Object>> neighborMap) {
        this.neighborMap = Objects.requireNonNull(neighborMap, "neighborMap");
        this.nodeCount = neighborMap.size();
        this.idList = new ArrayList<>(neighborMap.keySet());
        this.labelMap = new HashMap<>();
        for (int i = 0; i < idList.size(); i++) {
            labelMap.put(idList.get(i), i);
        }
    }

    /**
     * Sets the maximum number of propagation passes. A value of zero or less disables
     * propagation, so every node keeps its initial label.
     *
     * @param maxIter maximum number of passes
     * @throws NullPointerException if {@code maxIter} is {@code null}
     */
    public void setMaxIter(Integer maxIter) {
        this.maxIterations = Objects.requireNonNull(maxIter, "maxIter");
    }

    /**
     * Runs label propagation and renumbers the resulting labels as {@code 1..k}; read the result
     * with {@link #getLabelMap()}.
     *
     * @throws ExecutionException if a worker task fails (parallel mode only)
     * @throws InterruptedException if interrupted while waiting for workers (parallel mode only)
     */
    public void execute() throws ExecutionException, InterruptedException {
        if (nodeCount > PARALLEL_THRESHOLD) {
            executeParallel();
        } else {
            executeSimple();
        }
        renumberLabels();
    }

    /** Renumbers labels densely as 1..k, in the order labels first occur in labelMap's values. */
    private void renumberLabels() {
        List<Integer> distinctLabels =
                labelMap.values().stream().distinct().collect(Collectors.toList());
        Map<Integer, Integer> denseLabel = new HashMap<>();
        for (int i = 0; i < distinctLabels.size(); i++) {
            denseLabel.put(distinctLabels.get(i), i + 1);
        }
        labelMap.replaceAll((id, label) -> denseLabel.get(label));
    }

    /**
     * Runs propagation passes on the calling thread, updating one node at a time. Labels are left
     * un-renumbered; {@link #execute()} does the renumbering.
     */
    public void executeSimple() {
        FindDominantLabel updater = new FindDominantLabel();
        boolean changed = true;
        for (int iter = 0; changed && iter < maxIterations; iter++) {
            changed = false;
            Collections.shuffle(idList);
            for (Object id : idList) {
                updater.setId(id);
                if (updater.call()) {
                    changed = true;
                }
            }
        }
    }

    /**
     * Runs propagation passes on a fixed pool of {@value #PARALLEL_THREADS} threads, updating that
     * many nodes concurrently per batch. Labels are left un-renumbered; {@link #execute()} does
     * the renumbering. The pool is always shut down before returning.
     *
     * @throws ExecutionException if any worker task fails
     * @throws InterruptedException if interrupted while waiting for a batch
     */
    public void executeParallel() throws InterruptedException, ExecutionException {
        ExecutorService threadPool = Executors.newFixedThreadPool(PARALLEL_THREADS);
        try {
            List<FindDominantLabel> tasks = new ArrayList<>(PARALLEL_THREADS);
            for (int i = 0; i < PARALLEL_THREADS; i++) {
                tasks.add(new FindDominantLabel());
            }
            boolean changed = true;
            for (int iter = 0; changed && iter < maxIterations; iter++) {
                changed = false;
                Collections.shuffle(idList);
                for (int start = 0; start < idList.size(); start += PARALLEL_THREADS) {
                    for (int j = 0; j < PARALLEL_THREADS; j++) {
                        int index = start + j;
                        // Surplus tasks in the final, partial batch get a null id and do nothing.
                        tasks.get(j).setId(index < idList.size() ? idList.get(index) : null);
                    }
                    // invokeAll blocks until every task of the batch has completed. Every future
                    // is inspected so that a failure in any task surfaces as ExecutionException.
                    for (Future<Boolean> result : threadPool.invokeAll(tasks)) {
                        if (Boolean.TRUE.equals(result.get())) {
                            changed = true;
                        }
                    }
                }
            }
        } finally {
            threadPool.shutdown();
        }
    }

    /**
     * Returns the live (not copied) node-to-label map: initial labels {@code 0..N-1} before any
     * run, community ids {@code 1..k} after {@link #execute()}.
     */
    public Map<Object, Integer> getLabelMap() {
        return labelMap;
    }

    /**
     * Re-labels a single node (set via {@link #setId}) with the most frequent label among its
     * neighbors. Instances are reused: the caller sets a new id before each invocation.
     */
    private class FindDominantLabel implements Callable<Boolean> {
        private Object id;

        public void setId(Object id) {
            this.id = id;
        }

        /**
         * @return {@code true} if the node's label changed; {@code false} if it kept its label,
         *     has no labeled neighbors, or no id is set
         */
        @Override
        public Boolean call() {
            if (id == null) {
                return false;
            }
            List<Object> neighbors = neighborMap.get(id);
            if (neighbors == null || neighbors.isEmpty()) {
                return false;
            }
            // Counts are sized by the node's degree rather than the graph size, so one visit
            // costs O(degree) instead of O(N).
            Map<Integer, Integer> labelCount = new HashMap<>();
            List<Integer> dominantLabels = new ArrayList<>();
            int maxCount = 0;
            for (Object neighborId : neighbors) {
                Integer neighborLabel = labelMap.get(neighborId);
                if (neighborLabel == null) {
                    // Neighbor is not a key of neighborMap, so it has no label; ignore it.
                    continue;
                }
                int count = labelCount.merge(neighborLabel, 1, Integer::sum);
                if (count > maxCount) {
                    maxCount = count;
                    dominantLabels.clear();
                    dominantLabels.add(neighborLabel);
                } else if (count == maxCount) {
                    dominantLabels.add(neighborLabel);
                }
            }
            if (dominantLabels.isEmpty()) {
                return false;
            }
            // ThreadLocalRandom avoids contention on a shared generator between worker threads.
            Integer newLabel =
                    dominantLabels.get(ThreadLocalRandom.current().nextInt(dominantLabels.size()));
            Integer oldLabel = labelMap.put(id, newLabel);
            return !newLabel.equals(oldLabel);
        }
    }
}
